﻿#pragma kernel ParallelTransport
#pragma kernel Decimate
#pragma kernel ChaikinSmooth

#include "PathFrame.cginc"

// Per-chunk smoothing settings and results. The kernels below index pathData with the chunk index (id.x).
// NOTE(review): the field order/sizes presumably have to mirror a CPU-side struct - confirm before changing the layout.
struct smootherPathData
{
    uint smoothing;             // Chaikin subdivision level; ChaikinSmooth copies the frames unchanged when this is 0.
    float decimation;           // decimation amount; Decimate skips chunks where this is < 0.00001 and uses decimation^2 * 0.01 as squared-distance threshold.
    float twist;                // forwarded to pathFrame.Transport() by ParallelTransport.
    float restLength;           // not accessed by any kernel visible in this file.
    float smoothLength;         // output: summed distance between consecutive smoothed frames, written by ChaikinSmooth.
    bool usesOrientedParticles; // when true, frame axes are taken from particle orientations instead of being parallel transported.
};

// Per-particle inputs, indexed by particle index:
StructuredBuffer<float4> renderablePositions;           // xyz is used as the frame position.
StructuredBuffer<quaternion> renderableOrientations;    // converted to a matrix whose columns become the frame axes (oriented particles only).
StructuredBuffer<float4> principalRadii;                // component 0 is used as the frame thickness.
StructuredBuffer<float4> colors;                        // copied into the frame color.

// Per-chunk settings/results, and the particle index for each raw frame slot:
RWStructuredBuffer<smootherPathData> pathData;
StructuredBuffer<int> particleIndices;

// Raw frames: written by ParallelTransport, compacted in place by Decimate, read by ChaikinSmooth.
// frameOffsets[i] is the END (exclusive) of chunk i's range; chunk i starts at frameOffsets[i - 1] (0 for the first chunk).
RWStructuredBuffer<pathFrame> pathFrames;
StructuredBuffer<int> frameOffsets;
RWStructuredBuffer<int> decimatedFrameCounts;           // frames left in each chunk after Decimate.

// Smoothed frames written by ChaikinSmooth.
// Unlike frameOffsets, smoothFrameOffsets[i] is the START of chunk i's output range.
RWStructuredBuffer<pathFrame> smoothFrames;
StructuredBuffer<int> smoothFrameOffsets;
RWStructuredBuffer<int> smoothFrameCounts;              // frames written for each chunk by ChaikinSmooth.

// Variables set from the CPU
uint chunkCount;                                        // number of chunks; threads with id.x >= chunkCount exit immediately.

// Fills 'frame' with the position, thickness and color of the particle at 'particleIndex'.
// When 'useOrientedParticles' is set, the frame axes are also overwritten with the columns of the
// particle's orientation matrix; otherwise normal/binormal/tangent are left untouched.
// With 'interpolateOrientation' the orientation is first blended halfway with the orientation stored
// at the preceding particle index (clamped to 0).
void PathFrameFromParticle(inout pathFrame frame, int particleIndex, bool useOrientedParticles, bool interpolateOrientation = false)
{
    frame.position  = renderablePositions[particleIndex].xyz;
    frame.thickness = principalRadii[particleIndex][0];
    frame.color     = colors[particleIndex];

    // Without oriented particles only the attributes above are taken from the particle.
    if (!useOrientedParticles)
        return;

    quaternion orientation = renderableOrientations[particleIndex];

    if (interpolateOrientation)
    {
        quaternion neighbour = renderableOrientations[max(0, particleIndex - 1)];
        orientation = q_slerp(orientation, neighbour, 0.5f);
    }

    // Matrix columns 0 / 1 / 2 become binormal / normal / tangent respectively.
    float4x4 basis = q_toMatrix(orientation);
    frame.binormal = basis._m00_m10_m20;
    frame.normal   = basis._m01_m11_m21;
    frame.tangent  = basis._m02_m12_m22;
}


// One thread per chunk: generates a frame for every frame slot of the chunk and stores it in pathFrames.
// Chunks flagged as using oriented particles take their axes straight from the particles; all others
// get their tangent from the averaged forward/backward difference and are parallel transported.
[numthreads(128, 1, 1)]
void ParallelTransport (uint3 id : SV_DispatchThreadID) 
{
    const uint chunk = id.x;
    if (chunk >= chunkCount)
        return;

    // frameOffsets stores the end of each chunk, so the start is the previous chunk's end.
    const int chunkStart = chunk > 0 ? frameOffsets[chunk - 1] : 0;
    const int chunkFrames = frameOffsets[chunk] - chunkStart;

    // per-chunk settings, constant for the whole loop:
    const bool oriented = pathData[chunk].usesOrientedParticles;
    const float twist = pathData[chunk].twist;

    pathFrame previous;
    pathFrame current;
    pathFrame upcoming;

    previous.Reset();
    current.Reset();
    upcoming.Reset();

    // seed both the current and the previous frame with the chunk's first particle:
    PathFrameFromParticle(current, particleIndices[chunkStart], oriented, false);
    previous = current;

    for (int slot = 0; slot < chunkFrames; ++slot)
    {
        // look one frame ahead; the last slot looks at itself.
        const int lookAhead = min(slot + 1, chunkFrames - 1);
        PathFrameFromParticle(upcoming, particleIndices[chunkStart + lookAhead], oriented);

        if (oriented)
        {
            // the particle already provides a complete frame.
            previous = current;
        }
        else
        {
            // tangent = normalized sum of backward and forward deltas,
            // falling back to the previous tangent when the sum is too small.
            float3 backward = current.position - previous.position;
            float3 forward = upcoming.position - current.position;
            current.tangent = normalizesafe(backward + forward, previous.tangent);
            previous.Transport(current, twist);
        }

        pathFrames[chunkStart + slot] = previous;

        // step forward along the path:
        current = upcoming;
    }
}

// One thread per chunk: removes frames that lie close to the segment joining the frames kept around
// them, compacting the survivors in place at the start of the chunk's range in pathFrames.
// The number of surviving frames is written to decimatedFrameCounts[chunk].
// Chunks with (almost) zero decimation or fewer than 3 frames are left untouched.
[numthreads(128, 1, 1)]
void Decimate (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;
    if (i >= chunkCount) return;

    int firstInputIndex = i > 0 ? frameOffsets[i - 1] : 0;
    int inputFrameCount = frameOffsets[i] - firstInputIndex;

    // read the setting once instead of fetching it from the buffer for every use:
    float decimation = pathData[i].decimation;

    // no decimation, no work to do, just return:
    if (decimation < 0.00001f || inputFrameCount < 3)
    {
        decimatedFrameCounts[i] = inputFrameCount;
        return;
    }

    // squared-distance threshold (distances below are compared squared):
    float scaledThreshold = decimation * decimation * 0.01f;

    int start = 0;
    int end = inputFrameCount - 1;

    // Kept-frame counter lives in a local and is stored once at the end, instead of doing a
    // read-modify-write on the RW buffer for every kept frame.
    int keptCount = 0;

    while (start < end)
    {
        // add starting point. keptCount never exceeds start, so this only overwrites
        // slots that have already been consumed (or the slot itself).
        pathFrames[firstInputIndex + keptCount++] = pathFrames[firstInputIndex + start];

        // the segment start does not change while searching for this segment's end:
        float3 segmentStart = pathFrames[firstInputIndex + start].position;

        int newEnd = end;

        while (true)
        {
            int maxDistanceIndex = 0;
            float maxDistance = 0;
            float mu;

            float3 segmentEnd = pathFrames[firstInputIndex + newEnd].position;

            // find the point that's furthest away from the current segment:
            for (int j = start + 1; j < newEnd; j++)
            {
                float3 candidate = pathFrames[firstInputIndex + j].position;
                float3 nearest = NearestPointOnEdge(segmentStart, segmentEnd, candidate, mu);

                float3 delta = nearest - candidate;
                float d = dot(delta, delta);

                if (d > maxDistance)
                {
                    maxDistanceIndex = j;
                    maxDistance = d;
                }
            }

            // every in-between point is close enough: the segment [start, newEnd] is accepted.
            if (maxDistance <= scaledThreshold)
                break;

            // otherwise shorten the segment so it ends at the furthest point and retry.
            newEnd = maxDistanceIndex;
        }

        start = newEnd;
    }

    // add the last point:
    pathFrames[firstInputIndex + keptCount++] = pathFrames[firstInputIndex + end];

    decimatedFrameCounts[i] = keptCount;
}

// One thread per chunk: applies k levels of Chaikin smoothing (k = pathData[chunk].smoothing) to the
// chunk's decimated frames, evaluated in closed form, and writes the result to smoothFrames starting
// at smoothFrameOffsets[chunk]. The first and last output frames always coincide with the first and
// last input frames. Also stores the output frame count and the length of the smoothed polyline.
[numthreads(128, 1, 1)]
void ChaikinSmooth (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;
    if (i >= chunkCount) return;

    int firstInputIndex = i > 0 ? frameOffsets[i - 1] : 0;
    int inputFrameCount = decimatedFrameCounts[i];

    int firstOutputIndex = smoothFrameOffsets[i];

    int k = (int)pathData[i].smoothing;

    // output count is kept in a local and stored once, instead of re-reading the RW buffer.
    int outputFrameCount;

    // No work to do, just copy the input to the output. Chunks with fewer than 3 frames have no
    // interior frame to smooth: for exactly 2 frames the smoothing formula yields the same 2 frames,
    // and for fewer it would produce a non-positive count and out-of-range indices.
    if (k == 0 || inputFrameCount < 3)
    {
        outputFrameCount = inputFrameCount;
        for (int c = 0; c < inputFrameCount; ++c)
            smoothFrames[firstOutputIndex + c] = pathFrames[firstInputIndex + c];
    }
    else
    {
        // precalculate some quantities. Powers of two are built from an integer shift and exact
        // float operations: truncating the result of a floating point pow() to int can end up one
        // below the intended value if pow() returns slightly less than the exact power.
        int pCount = 1 << k;                                            // 2^k
        int n0 = inputFrameCount - 1;
        float twoRaisedToMinusK = 1.0f / (float)pCount;                 // 2^-k
        float twoRaisedToMinusKPlus1 = 0.5f * twoRaisedToMinusK;        // 2^-(k+1)
        float twoRaisedToMinus2K = twoRaisedToMinusK * twoRaisedToMinusK; // 2^-2k
        float twoRaisedToMinus2KMinus1 = 0.5f * twoRaisedToMinus2K;     // 2^-(2k+1)

        outputFrameCount = (inputFrameCount - 2) * pCount + 2;

        // calculate internal points. These fill output slots 1 .. outputFrameCount - 2;
        // slots 0 and outputFrameCount - 1 are set to the original end points below.
        for (int j = 1; j <= pCount; ++j)
        {
            // precalculate coefficients:
            float F = 0.5f - twoRaisedToMinusKPlus1 - (j - 1) * (twoRaisedToMinusK - j * twoRaisedToMinus2KMinus1);
            float G = 0.5f + twoRaisedToMinusKPlus1 + (j - 1) * (twoRaisedToMinusK - j * twoRaisedToMinus2K);
            float H = (j - 1) * j * twoRaisedToMinus2KMinus1;

            for (int l = 1; l < n0; ++l)
            {
                WeightedSum(F, G, H,
                            pathFrames[firstInputIndex + l - 1],
                            pathFrames[firstInputIndex + l],
                            pathFrames[firstInputIndex + l + 1],
                            smoothFrames[firstOutputIndex + (l - 1) * pCount + j]);
            }
        }

        // make first and last curve points coincide with original points:
        smoothFrames[firstOutputIndex] = pathFrames[firstInputIndex];
        smoothFrames[firstOutputIndex + outputFrameCount - 1] = pathFrames[firstInputIndex + inputFrameCount - 1];
    }

    smoothFrameCounts[i] = outputFrameCount;

    // calculate path length: accumulate locally and store the total once.
    float accumulatedLength = 0;
    for (int s = 1; s < outputFrameCount; ++s)
        accumulatedLength += distance(smoothFrames[firstOutputIndex + s - 1].position,
                                      smoothFrames[firstOutputIndex + s].position);

    pathData[i].smoothLength = accumulatedLength;
}
